{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8be8252",
   "metadata": {},
   "outputs": [],
   "source": [
    "!uv pip install pytest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba193fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import ast\n",
    "import sys\n",
    "import uuid\n",
    "import json\n",
    "import textwrap\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "from dataclasses import dataclass\n",
    "from typing import List, Protocol, Tuple, Dict, Optional\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "from openai import BadRequestError as _OpenAIBadRequest\n",
    "import gradio as gr\n",
    "\n",
    "load_dotenv(override=True)\n",
    "\n",
    "# --- Provider base URLs (Gemini & Groq speak OpenAI-compatible API) ---\n",
    "GEMINI_BASE = \"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
    "GROQ_BASE   = \"https://api.groq.com/openai/v1\"\n",
    "\n",
    "# --- API Keys (add these in your .env) ---\n",
    "openai_api_key = os.getenv(\"OPENAI_API_KEY\")   # OpenAI\n",
    "google_api_key = os.getenv(\"GOOGLE_API_KEY\")   # Gemini\n",
    "groq_api_key   = os.getenv(\"GROQ_API_KEY\")     # Groq\n",
    "\n",
    "# --- Clients ---\n",
    "# Every client is built only when its key is present. OpenAI() raises at\n",
    "# construction time if OPENAI_API_KEY is missing, which would abort this cell\n",
    "# even when Gemini or Groq are configured; _register() skips None clients.\n",
    "openai_client = OpenAI() if openai_api_key else None  # reads OPENAI_API_KEY\n",
    "gemini_client = OpenAI(api_key=google_api_key, base_url=GEMINI_BASE) if google_api_key else None\n",
    "groq_client   = OpenAI(api_key=groq_api_key,   base_url=GROQ_BASE)   if groq_api_key   else None\n",
    "\n",
    "# --- Model registry: label -> { client, model } ---\n",
    "MODEL_REGISTRY: Dict[str, Dict[str, object]] = {}\n",
    "\n",
    "def _register(label: str, client: Optional[OpenAI], model_id: str):\n",
    "    \"\"\"Record `label` in MODEL_REGISTRY; providers without a client are skipped.\"\"\"\n",
    "    if client is None:\n",
    "        return\n",
    "    entry = {\"client\": client, \"model\": model_id}\n",
    "    MODEL_REGISTRY[label] = entry\n",
    "\n",
    "# OpenAI\n",
    "_register(\"OpenAI • GPT-5\",        openai_client, \"gpt-5\")\n",
    "_register(\"OpenAI • GPT-5 Nano\",   openai_client, \"gpt-5-nano\")\n",
    "_register(\"OpenAI • GPT-4o-mini\",  openai_client, \"gpt-4o-mini\")\n",
    "\n",
    "# Gemini (Google)\n",
    "_register(\"Gemini • 2.5 Pro\",      gemini_client, \"gemini-2.5-pro\")\n",
    "_register(\"Gemini • 2.5 Flash\",    gemini_client, \"gemini-2.5-flash\")\n",
    "\n",
    "# Groq\n",
    "_register(\"Groq • Llama 3.1 8B\",   groq_client,   \"llama-3.1-8b-instant\")\n",
    "_register(\"Groq • Llama 3.3 70B\",  groq_client,   \"llama-3.3-70b-versatile\")\n",
    "_register(\"Groq • GPT-OSS 20B\",    groq_client,   \"openai/gpt-oss-20b\")\n",
    "_register(\"Groq • GPT-OSS 120B\",   groq_client,   \"openai/gpt-oss-120b\")\n",
    "\n",
    "DEFAULT_MODEL = next(iter(MODEL_REGISTRY.keys()), None)\n",
    "\n",
    "print(f\"Providers configured → OpenAI:{bool(openai_api_key)}  Gemini:{bool(google_api_key)}  Groq:{bool(groq_api_key)}\")\n",
    "print(\"Models available     →\", \", \".join(MODEL_REGISTRY.keys()) or \"None (add API keys in .env)\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5d6b0f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CompletionClient(Protocol):\n",
    "    \"\"\"Structural type for LLM backends addressed by a registry label.\"\"\"\n",
    "\n",
    "    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
    "        \"\"\"Return the completion text for the given system and user prompts.\"\"\"\n",
    "        ...\n",
    "\n",
    "\n",
    "def _extract_code_or_text(s: str) -> str:\n",
    "    \"\"\"Return the body of the first fenced code block in `s`, else `s` itself.\n",
    "\n",
    "    The opening fence may carry a Python language tag (python, python3 or py,\n",
    "    case-insensitive). The tag is consumed so it cannot leak into the returned\n",
    "    code. Text without a complete pair of fences is returned stripped.\n",
    "    \"\"\"\n",
    "    # The tag must end at a word boundary so code that merely starts with\n",
    "    # \"py...\" right after the fence is not truncated.\n",
    "    m = re.search(r\"```(?:(?:python3?|py)\\b)?\\s*(.*?)```\", s, flags=re.S | re.I)\n",
    "    return m.group(1).strip() if m else s.strip()\n",
    "\n",
    "\n",
    "class MultiModelChatClient:\n",
    "    \"\"\"Routes completion requests to the provider client registered for a label.\n",
    "\n",
    "    `registry` maps a display label to a dict with the keys \"client\" (the\n",
    "    client object used for the call) and \"model\" (the provider's model id).\n",
    "    \"\"\"\n",
    "    def __init__(self, registry: Dict[str, Dict[str, object]]):\n",
    "        self._registry = registry\n",
    "\n",
    "    def _call(self, *, client: OpenAI, model_id: str, system: str, user: str) -> str:\n",
    "        \"\"\"Send one system + user chat completion and return the extracted code or text.\"\"\"\n",
    "        # Only model and messages are sent: strict providers reject extra\n",
    "        # sampling parameters such as temperature.\n",
    "        resp = client.chat.completions.create(\n",
    "            model=model_id,\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": system},\n",
    "                {\"role\": \"user\",   \"content\": user},\n",
    "            ],\n",
    "        )\n",
    "        # message.content may be None; treat that as an empty reply.\n",
    "        text = (resp.choices[0].message.content or \"\").strip()\n",
    "        return _extract_code_or_text(text)\n",
    "\n",
    "    def complete(self, *, model_label: str, system: str, user: str) -> str:\n",
    "        \"\"\"Run one completion with the model registered under `model_label`.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if `model_label` is not in the registry.\n",
    "\n",
    "        Errors raised by the provider client propagate unchanged. The request\n",
    "        is not retried: it carries no optional parameters, so re-sending the\n",
    "        identical payload after a rejection could only fail the same way.\n",
    "        \"\"\"\n",
    "        if model_label not in self._registry:\n",
    "            raise ValueError(f\"Unknown model label: {model_label}\")\n",
    "        info   = self._registry[model_label]\n",
    "        client = info[\"client\"]\n",
    "        model  = info[\"model\"]\n",
    "        return self._call(client=client, model_id=str(model), system=system, user=user)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31558bf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class SymbolInfo:\n",
    "    \"\"\"One public symbol discovered in a module.\"\"\"\n",
    "    kind: str       # \"function\" | \"class\" | \"method\"\n",
    "    name: str       # bare name, or \"Class.method\" for methods\n",
    "    signature: str  # rendered \"def ...:\" header; just the class name for classes\n",
    "    lineno: int     # line number reported by ast for the definition\n",
    "\n",
    "class PublicAPIExtractor:\n",
    "    \"\"\"Extract concise 'public API' summary from a Python module.\"\"\"\n",
    "    def extract(self, source: str) -> List[SymbolInfo]:\n",
    "        \"\"\"Return public top-level functions, classes, and their public methods.\n",
    "\n",
    "        Names starting with \"_\" are skipped (this includes dunder methods).\n",
    "        Both `def` and `async def` definitions are reported. The result is\n",
    "        sorted by (kind, lower-cased name, line number).\n",
    "\n",
    "        Raises:\n",
    "            SyntaxError: if `source` is not valid Python (from ast.parse).\n",
    "        \"\"\"\n",
    "        func_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)\n",
    "        tree = ast.parse(source)\n",
    "        out: List[SymbolInfo] = []\n",
    "        for node in tree.body:\n",
    "            if isinstance(node, func_nodes) and not node.name.startswith(\"_\"):\n",
    "                out.append(SymbolInfo(\"function\", node.name, self._sig(node), node.lineno))\n",
    "            elif isinstance(node, ast.ClassDef) and not node.name.startswith(\"_\"):\n",
    "                out.append(SymbolInfo(\"class\", node.name, node.name, node.lineno))\n",
    "                for sub in node.body:\n",
    "                    if isinstance(sub, func_nodes) and not sub.name.startswith(\"_\"):\n",
    "                        out.append(SymbolInfo(\"method\",\n",
    "                                              f\"{node.name}.{sub.name}\",\n",
    "                                              self._sig(sub),\n",
    "                                              sub.lineno))\n",
    "        return sorted(out, key=lambda s: (s.kind, s.name.lower(), s.lineno))\n",
    "\n",
    "    def _sig(self, fn: \"ast.FunctionDef | ast.AsyncFunctionDef\") -> str:\n",
    "        \"\"\"Render a one-line header such as 'def f(a, /, b=?, *, c) -> int:'.\n",
    "\n",
    "        Parameter annotations are omitted. A parameter that has a default is\n",
    "        shown as 'name=?'; the default value itself is not rendered.\n",
    "        \"\"\"\n",
    "        spec = fn.args\n",
    "        positional = [*spec.posonlyargs, *spec.args]\n",
    "        # ast stores defaults only for the trailing positional parameters.\n",
    "        first_default = len(positional) - len(spec.defaults)\n",
    "        last_posonly = len(spec.posonlyargs) - 1\n",
    "\n",
    "        args = []\n",
    "        for i, a in enumerate(positional):\n",
    "            args.append(a.arg + (\"=?\" if i >= first_default else \"\"))\n",
    "            if i == last_posonly:\n",
    "                args.append(\"/\")\n",
    "        if spec.vararg:\n",
    "            args.append(\"*\" + spec.vararg.arg)\n",
    "        elif spec.kwonlyargs:\n",
    "            # Bare '*' keeps keyword-only parameters from reading as positional.\n",
    "            args.append(\"*\")\n",
    "        # kw_defaults is parallel to kwonlyargs; None marks a required parameter.\n",
    "        for a, default in zip(spec.kwonlyargs, spec.kw_defaults):\n",
    "            args.append(a.arg + (\"=?\" if default is not None else \"\"))\n",
    "        if spec.kwarg:\n",
    "            args.append(\"**\" + spec.kwarg.arg)\n",
    "\n",
    "        ret = \"\"\n",
    "        if fn.returns is not None:\n",
    "            try:\n",
    "                ret = f\" -> {ast.unparse(fn.returns)}\"\n",
    "            except Exception:\n",
    "                # Best effort: the return annotation is optional in the summary.\n",
    "                pass\n",
    "        prefix = \"async def\" if isinstance(fn, ast.AsyncFunctionDef) else \"def\"\n",
    "        return f\"{prefix} {fn.name}({', '.join(args)}){ret}:\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3aeadedc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PromptBuilder:\n",
    "    \"\"\"Builds deterministic prompts for pytest generation.\"\"\"\n",
    "    SYSTEM = (\n",
    "        \"You are a senior Python engineer. Produce a single, self-contained pytest file.\\n\"\n",
    "        \"Rules:\\n\"\n",
    "        \"- Output only Python test code (no prose, no markdown fences).\\n\"\n",
    "        \"- Use plain pytest tests (functions), no classes unless unavoidable.\\n\"\n",
    "        \"- Deterministic: avoid network/IO; seed randomness if used.\\n\"\n",
    "        \"- Import the target module by module name only.\\n\"\n",
    "        \"- Cover every public function and method with at least one tiny test.\\n\"\n",
    "        \"- Prefer straightforward, fast assertions.\\n\"\n",
    "    )\n",
    "\n",
    "    def build_user(self, *, module_name: str, source: str, symbols: List[SymbolInfo]) -> str:\n",
    "        \"\"\"Build the user prompt: task, public API summary, constraints, full source.\n",
    "\n",
    "        The prompt is joined from separate lines instead of running\n",
    "        textwrap.dedent over an f-string. The interpolated summary and module\n",
    "        source contain lines that start at column 0, so dedent found no common\n",
    "        margin: the template kept its indentation and only the first line of\n",
    "        the source was shifted right. Leading and trailing newlines of `source`\n",
    "        are dropped; its indentation is preserved.\n",
    "        \"\"\"\n",
    "        summary = \"\\n\".join(f\"- {s.kind:<6}  {s.signature}\" for s in symbols) or \"- (no public symbols)\"\n",
    "        sections = [\n",
    "            f\"Create pytest tests for module `{module_name}`.\",\n",
    "            \"\",\n",
    "            \"Public API Summary:\",\n",
    "            summary,\n",
    "            \"\",\n",
    "            \"Constraints:\",\n",
    "            f\"- Import as: `import {module_name} as mod`\",\n",
    "            \"- Keep tests tiny, fast, and deterministic.\",\n",
    "            \"\",\n",
    "            \"Full module source (for reference):\",\n",
    "            f\"# --- BEGIN SOURCE {module_name}.py ---\",\n",
    "            source.strip(\"\\n\"),\n",
    "            \"# --- END SOURCE ---\",\n",
    "        ]\n",
    "        return \"\\n\".join(sections)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a45ac5be",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _ensure_header_and_import(code: str, module_name: str) -> str:\n",
    "    \"\"\"Prepend `import pytest` and/or `import <module> as mod` when the tests lack them.\"\"\"\n",
    "    body = code.strip()\n",
    "    alias_import = f\"import {module_name} as mod\"\n",
    "    imports_module = alias_import in body or f\"from {module_name} import\" in body\n",
    "\n",
    "    # Collect only the lines that are absent (plain substring checks).\n",
    "    missing = []\n",
    "    if \"import pytest\" not in body:\n",
    "        missing.append(\"import pytest\")\n",
    "    if not imports_module:\n",
    "        missing.append(alias_import)\n",
    "\n",
    "    if not missing:\n",
    "        return body\n",
    "    return \"\\n\".join(missing) + \"\\n\\n\" + body\n",
    "\n",
    "\n",
    "def build_module_name_from_path(path: str) -> str:\n",
    "    \"\"\"Return the final component of `path` without its last suffix.\"\"\"\n",
    "    file_path = Path(path)\n",
    "    return file_path.stem\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "787e58b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestGenerator:\n",
    "    \"\"\"Pipeline: extract public API, build the prompt, query the model, fix up imports.\"\"\"\n",
    "\n",
    "    def __init__(self, llm: CompletionClient):\n",
    "        self._prompts = PromptBuilder()\n",
    "        self._extractor = PublicAPIExtractor()\n",
    "        self._llm = llm\n",
    "\n",
    "    def generate_tests(self, model_label: str, module_name: str, source: str) -> str:\n",
    "        \"\"\"Return pytest code for `source`, generated by the model behind `model_label`.\"\"\"\n",
    "        public_symbols = self._extractor.extract(source)\n",
    "        prompt = self._prompts.build_user(\n",
    "            module_name=module_name,\n",
    "            source=source,\n",
    "            symbols=public_symbols,\n",
    "        )\n",
    "        generated = self._llm.complete(\n",
    "            model_label=model_label,\n",
    "            system=self._prompts.SYSTEM,\n",
    "            user=prompt,\n",
    "        )\n",
    "        return _ensure_header_and_import(generated, module_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8402f62f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_pytest_summary(output: str) -> Tuple[str, Dict[str, int]]:\n",
    "    \"\"\"\n",
    "    Find the final summary line, like:\n",
    "      '3 passed, 1 failed, 2 skipped in 0.12s'\n",
    "    Return (summary_line, counts_dict).\n",
    "\n",
    "    The summary is the LAST line holding at least one '<N> <outcome>' pair, so\n",
    "    trailing noise (for example stderr text that merely mentions \"error\") is\n",
    "    not mistaken for it. counts_dict always has the keys passed, failed,\n",
    "    errors, skipped, xfail and xpassed. When no line qualifies, the summary is\n",
    "    '(no summary line found)' and every count is 0.\n",
    "    \"\"\"\n",
    "    count_re = re.compile(r\"(\\d+)\\s+(passed|failed|errors?|skipped|xfailed|xpassed)\\b\")\n",
    "    # Outcome word as printed in the summary -> key used in the returned dict.\n",
    "    key_for = {\n",
    "        \"passed\": \"passed\",\n",
    "        \"failed\": \"failed\",\n",
    "        \"error\": \"errors\",\n",
    "        \"errors\": \"errors\",\n",
    "        \"skipped\": \"skipped\",\n",
    "        \"xfailed\": \"xfail\",\n",
    "        \"xpassed\": \"xpassed\",\n",
    "    }\n",
    "    counts = {\"passed\": 0, \"failed\": 0, \"errors\": 0, \"skipped\": 0, \"xfail\": 0, \"xpassed\": 0}\n",
    "\n",
    "    summary_line = \"\"\n",
    "    for line in reversed(output.strip().splitlines()):  # scan from end\n",
    "        pairs = count_re.findall(line)\n",
    "        if not pairs:\n",
    "            continue\n",
    "        summary_line = line.strip()\n",
    "        for num, kind in pairs:\n",
    "            counts[key_for[kind]] += int(num)\n",
    "        break\n",
    "\n",
    "    return summary_line or \"(no summary line found)\", counts\n",
    "\n",
    "\n",
    "def run_pytest_on_snippet(module_name: str, module_code: str, tests_code: str, timeout: float = 120.0) -> Tuple[str, str]:\n",
    "    \"\"\"\n",
    "    Write the module and its tests into a fresh directory under .pytest_runs/,\n",
    "    run pytest there, and return (human_summary, full_cli_output).\n",
    "\n",
    "    module_name is stripped and must be a valid Python identifier because it is\n",
    "    used both as a file name and as the import name. The run directory is left\n",
    "    on disk after the run. If pytest does not finish within `timeout` seconds\n",
    "    the run is aborted and an error message is returned with empty output.\n",
    "    Invalid input returns an error message without touching the filesystem.\n",
    "    \"\"\"\n",
    "    module_name = (module_name or \"\").strip()\n",
    "    if not module_name or not (module_code or \"\").strip() or not (tests_code or \"\").strip():\n",
    "        return \"❌ Provide module name, module code, and tests.\", \"\"\n",
    "    if not module_name.isidentifier():\n",
    "        # Also rejects path separators and '..', so files cannot escape the run dir.\n",
    "        return \"❌ Module name must be a valid Python identifier.\", \"\"\n",
    "\n",
    "    run_id = uuid.uuid4().hex[:8]\n",
    "    # Absolute path: pytest runs with cwd=base, so a relative PYTHONPATH entry\n",
    "    # would be resolved against the wrong directory.\n",
    "    base = (Path(\".pytest_runs\") / f\"run_{run_id}\").resolve()\n",
    "    tests_dir = base / \"tests\"\n",
    "    tests_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Write module and tests\n",
    "    (base / f\"{module_name}.py\").write_text(module_code, encoding=\"utf-8\")\n",
    "    (tests_dir / f\"test_{module_name}.py\").write_text(tests_code, encoding=\"utf-8\")\n",
    "\n",
    "    # Run pytest with the run dir first on PYTHONPATH (no empty trailing entry)\n",
    "    env = os.environ.copy()\n",
    "    inherited = env.get(\"PYTHONPATH\")\n",
    "    env[\"PYTHONPATH\"] = os.pathsep.join([str(base), inherited]) if inherited else str(base)\n",
    "    cmd = [sys.executable, \"-m\", \"pytest\", \"-q\"]  # quiet output, but still includes summary\n",
    "    try:\n",
    "        proc = subprocess.run(cmd, cwd=base, env=env, text=True, capture_output=True, timeout=timeout)\n",
    "    except subprocess.TimeoutExpired:\n",
    "        return f\"❌ PyTest timed out after {timeout:g}s.\", \"\"\n",
    "\n",
    "    full_out = (proc.stdout or \"\") + (\"\\n\" + proc.stderr if proc.stderr else \"\")\n",
    "    summary_line, counts = _parse_pytest_summary(full_out)\n",
    "\n",
    "    # One bold badge per non-zero outcome, in a fixed display order.\n",
    "    badges = [\n",
    "        f\"**{key}: {counts[key]}**\"\n",
    "        for key in (\"passed\", \"failed\", \"errors\", \"skipped\", \"xpassed\", \"xfail\")\n",
    "        if counts.get(key, 0)\n",
    "    ]\n",
    "    badge_line = \"  •  \".join(badges) if badges else \"no tests collected?\"\n",
    "\n",
    "    human = f\"{summary_line}\\n\\n{badge_line}\"\n",
    "    return human, full_out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d240ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "LLM = MultiModelChatClient(MODEL_REGISTRY)\n",
    "SERVICE = TestGenerator(LLM)\n",
    "\n",
    "def generate_from_code(model_label: str, module_name: str, code: str, save: bool, out_dir: str) -> Tuple[str, str]:\n",
    "    \"\"\"Generate pytest tests for `code` and optionally save them to disk.\n",
    "\n",
    "    Returns (tests_code, status_message). On invalid input, or when the pasted\n",
    "    code cannot be parsed, tests_code is \"\" and the message starts with ❌.\n",
    "    On success the message is \"\" unless `save` is set, in which case it names\n",
    "    the written file (`out_dir` defaults to \"tests\" when empty).\n",
    "    \"\"\"\n",
    "    if not model_label or model_label not in MODEL_REGISTRY:\n",
    "        return \"\", \"❌ Pick a model (or add API keys for providers in .env).\"\n",
    "    # Normalise once so the prompt, the import line and the saved file name agree.\n",
    "    module_name = (module_name or \"\").strip()\n",
    "    if not module_name:\n",
    "        return \"\", \"❌ Please provide a module name.\"\n",
    "    if not module_name.isidentifier():\n",
    "        # The name ends up in `import <name> as mod` and in the output file name.\n",
    "        return \"\", \"❌ Module name must be a valid Python identifier.\"\n",
    "    if not (code or \"\").strip():\n",
    "        return \"\", \"❌ Please paste some Python code.\"\n",
    "\n",
    "    try:\n",
    "        tests_code = SERVICE.generate_tests(model_label=model_label, module_name=module_name, source=code)\n",
    "    except SyntaxError as exc:\n",
    "        # The API extraction step parses the pasted code with ast.parse.\n",
    "        return \"\", f\"❌ Could not parse the module code: {exc}\"\n",
    "\n",
    "    saved = \"\"\n",
    "    if save:\n",
    "        out = Path(out_dir or \"tests\")\n",
    "        out.mkdir(parents=True, exist_ok=True)\n",
    "        out_path = out / f\"test_{module_name}.py\"\n",
    "        out_path.write_text(tests_code, encoding=\"utf-8\")\n",
    "        saved = f\"✅ Saved to {out_path}\"\n",
    "    return tests_code, saved\n",
    "\n",
    "\n",
    "def generate_from_file(model_label: str, file_obj, save: bool, out_dir: str) -> Tuple[str, str]:\n",
    "    \"\"\"Generate tests from uploaded content, using the module name of 'uploaded_module.py'.\n",
    "\n",
    "    `file_obj` must offer .decode(\"utf-8\") (presumably raw bytes from the upload\n",
    "    widget; confirm against the caller). None yields an error status.\n",
    "    \"\"\"\n",
    "    if file_obj is not None:\n",
    "        source_text = file_obj.decode(\"utf-8\")\n",
    "        derived_name = build_module_name_from_path(\"uploaded_module.py\")\n",
    "        return generate_from_code(model_label, derived_name, source_text, save, out_dir)\n",
    "    return \"\", \"❌ Please upload a .py file.\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e1401a",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXAMPLE_CODE = \"\"\"\\\n",
    "def add(a: int, b: int) -> int:\n",
    "    return a + b\n",
    "\n",
    "def divide(a: float, b: float) -> float:\n",
    "    if b == 0:\n",
    "        raise ZeroDivisionError(\"b must be non-zero\")\n",
    "    return a / b\n",
    "\n",
    "class Counter:\n",
    "    def __init__(self, start: int = 0):\n",
    "        self.value = start\n",
    "\n",
    "    def inc(self, by: int = 1):\n",
    "        self.value += by\n",
    "        return self.value\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f802450e",
   "metadata": {},
   "outputs": [],
   "source": [
    "with gr.Blocks(title=\"PyTest Generator\") as ui:\n",
    "    gr.Markdown(\n",
    "        \"## 🧪 PyTest Generator (Week 4 • Community Contribution)\\n\"\n",
    "        \"Generate **minimal, deterministic** pytest tests from a Python module using your chosen model/provider.\"\n",
    "    )\n",
    "\n",
    "    with gr.Row(equal_height=True):\n",
    "        # LEFT: inputs (module code)\n",
    "        with gr.Column(scale=6):\n",
    "            with gr.Row():\n",
    "                model_dd = gr.Dropdown(\n",
    "                    list(MODEL_REGISTRY.keys()),\n",
    "                    value=DEFAULT_MODEL,\n",
    "                    label=\"Model (OpenAI, Gemini, Groq)\"\n",
    "                )\n",
    "                module_name_tb = gr.Textbox(\n",
    "                    label=\"Module name (used in `import <name> as mod`)\",\n",
    "                    value=\"mymodule\"\n",
    "                )\n",
    "            code_in = gr.Code(\n",
    "                label=\"Python module code\",\n",
    "                language=\"python\",\n",
    "                lines=24,\n",
    "                value=EXAMPLE_CODE\n",
    "            )\n",
    "            with gr.Row():\n",
    "                save_cb = gr.Checkbox(label=\"Also save generated tests to /tests\", value=True)\n",
    "                out_dir_tb = gr.Textbox(label=\"Output folder\", value=\"tests\")\n",
    "            gen_btn = gr.Button(\"Generate tests\", variant=\"primary\")\n",
    "\n",
    "        # RIGHT: outputs (generated tests + pytest run)\n",
    "        with gr.Column(scale=6):\n",
    "            tests_out = gr.Code(label=\"Generated tests (pytest)\", language=\"python\", lines=24)\n",
    "            with gr.Row():\n",
    "                run_btn = gr.Button(\"Run PyTest\", variant=\"secondary\")\n",
    "            summary_md = gr.Markdown()\n",
    "            full_out = gr.Textbox(label=\"Full PyTest output\", lines=12)\n",
    "\n",
    "    # --- events ---\n",
    "\n",
    "    def _on_gen(model_label, name, code, save, outdir):\n",
    "        tests, msg = generate_from_code(model_label, name, code, save, outdir)\n",
    "        status = msg or \"✅ Done\"\n",
    "        return tests, status\n",
    "\n",
    "    gen_btn.click(\n",
    "        _on_gen,\n",
    "        inputs=[model_dd, module_name_tb, code_in, save_cb, out_dir_tb],\n",
    "        outputs=[tests_out, summary_md],\n",
    "    )\n",
    "\n",
    "    def _on_run(name, code, tests):\n",
    "        summary, details = run_pytest_on_snippet(name, code, tests)\n",
    "        return summary, details\n",
    "\n",
    "    run_btn.click(\n",
    "        _on_run,\n",
    "        inputs=[module_name_tb, code_in, tests_out],\n",
    "        outputs=[summary_md, full_out],\n",
    "    )\n",
    "\n",
    "ui.launch(inbrowser=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-engineering",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
